{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Batch Predictions\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Amazon SageMaker Batch Transform allows you to make predictions on batches of data in S3 without setting up a REST endpoint. Batch predictions are also called “offline” predictions since they do not require an online REST endpoint. Typically meant for higher-throughput workloads that can tolerate higher latency and lower freshness, batch prediction servers typically do not run 24 hours per day like real-time prediction servers. They run for a few hours on a batch of data, then shut down - hence the term, “batch.” Batch Transform manages all of the resources needed to perform the inferences including the launch and termination of the cluster after the job completes."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](img/batch_transform_tensorflow.gif)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "import sagemaker\n",
    "import pandas as pd\n",
    "\n",
    "sess   = sagemaker.Session()\n",
    "bucket = sess.default_bucket()\n",
    "role = sagemaker.get_execution_role()\n",
    "region = boto3.Session().region_name\n",
    "\n",
    "sm = boto3.Session().client(service_name='sagemaker', region_name=region)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup Batch Transform Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%store -r training_job_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(training_job_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!aws s3 cp s3://$bucket/$training_job_name/output/model.tar.gz ./model.tar.gz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!mkdir -p ./model/\n",
    "!tar -xvzf ./model.tar.gz -C ./model/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!saved_model_cli show --all --dir ./model/tensorflow/saved_model/0/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pygmentize ./src_batch_tsv/inference.py"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Configure TensorFlow Serving for Batch Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.tensorflow.serving import Model\n",
    "\n",
    "batch_env = {\n",
    "  # Configures whether to enable record batching.\n",
    "  'SAGEMAKER_TFS_ENABLE_BATCHING': 'true',\n",
    "\n",
    "  # Name of the model - this is important in multi-model deployments\n",
    "  'SAGEMAKER_TFS_DEFAULT_MODEL_NAME': 'saved_model',\n",
    "\n",
    "  # Configures how long to wait for a full batch, in microseconds.\n",
    "  'SAGEMAKER_TFS_BATCH_TIMEOUT_MICROS': '50000', # microseconds\n",
    "\n",
    "  # Corresponds to \"max_batch_size\" in TensorFlow Serving.\n",
    "  'SAGEMAKER_TFS_MAX_BATCH_SIZE': '10000',\n",
    "\n",
    "  # Number of seconds for the SageMaker web server timeout\n",
    "  'SAGEMAKER_MODEL_SERVER_TIMEOUT': '7200', # Seconds\n",
    "\n",
    "  # Configures number of batches that can be enqueued.\n",
    "  'SAGEMAKER_TFS_MAX_ENQUEUED_BATCHES': '10000'\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Configure the Parallelism and Payload Size\n",
    "To increase performance, you can increase the max_concurrent_transforms parameter.  Tune this on a single instance before trying to scale out the number of instances - especially if you have a small file count, the multiple instances can be a big waste.  Note that `max_concurrent_transforms * max_payload <= 100`"
   ]
  },
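  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The limit noted above can be checked before a job is launched. The following is a minimal sketch, not part of the SageMaker SDK: `check_transform_settings` and `MAX_TOTAL_PAYLOAD_MB` are hypothetical names introduced here for illustration.\n",
    "\n",
    "```python\n",
    "# Hypothetical helper (assumption): not a SageMaker SDK function.\n",
    "MAX_TOTAL_PAYLOAD_MB = 100  # limit stated in the note above\n",
    "\n",
    "\n",
    "def check_transform_settings(max_concurrent_transforms, max_payload_mb):\n",
    "    # Return the combined payload in MB, or raise if it exceeds the limit.\n",
    "    total_mb = max_concurrent_transforms * max_payload_mb\n",
    "    if total_mb > MAX_TOTAL_PAYLOAD_MB:\n",
    "        raise ValueError(\n",
    "            'max_concurrent_transforms * max_payload = {} MB exceeds {} MB'.format(\n",
    "                total_mb, MAX_TOTAL_PAYLOAD_MB))\n",
    "    return total_mb\n",
    "\n",
    "\n",
    "print(check_transform_settings(max_concurrent_transforms=1, max_payload_mb=1))\n",
    "```\n",
    "\n",
    "Running a check like this while tuning on a single instance can catch an invalid combination before the transform job is submitted."
   ]
  },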
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_concurrent_transforms=1\n",
    "max_payload=1 # Megabytes (not number of records)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup Instance Type and Instance Count for Our Cluster"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "instance_type='ml.c5.18xlarge'\n",
    "instance_count=1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup Input Data and Configuration\n",
    "This include Single vs. MultiRecord, compression_type, accept_type, content_type, split types, etc."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "strategy='MultiRecord'\n",
    "compression_type='Gzip'\n",
    "accept_type='text/csv'\n",
    "content_type='text/csv'\n",
    "assemble_with='Line'\n",
    "split_type='Line'"
   ]
  },
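  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough mental model of how `split_type='Line'` and `strategy='MultiRecord'` interact with the payload limit, the sketch below groups lines into mini-batches that each stay under a size limit. It is an approximation for illustration only, not SageMaker's actual implementation, and `pack_lines` is a hypothetical name. With a single-record strategy, each line would instead be sent as its own request.\n",
    "\n",
    "```python\n",
    "# Illustrative only (assumption): approximates line splitting with a\n",
    "# multi-record strategy. This is not SageMaker's implementation.\n",
    "def pack_lines(lines, max_payload_mb):\n",
    "    limit = max_payload_mb * 1024 * 1024  # payload limit in bytes\n",
    "    batches, current, size = [], [], 0\n",
    "    for line in lines:\n",
    "        record_size = len(line.encode('utf-8')) + 1  # +1 for the newline\n",
    "        # Start a new mini-batch when this record would exceed the limit.\n",
    "        if current and size + record_size > limit:\n",
    "            batches.append(current)\n",
    "            current, size = [], 0\n",
    "        current.append(line)\n",
    "        size += record_size\n",
    "    if current:\n",
    "        batches.append(current)\n",
    "    return batches\n",
    "\n",
    "\n",
    "rows = ['row-{}'.format(i) for i in range(5)]\n",
    "print(len(pack_lines(rows, max_payload_mb=1)))\n",
    "```\n",
    "\n",
    "In this sketch a record larger than the limit still ends up alone in its own mini-batch; how the real service handles oversized records is outside the scope of this illustration."
   ]
  },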
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store -r raw_input_data_s3_uri"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(raw_input_data_s3_uri)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!aws s3 ls $raw_input_data_s3_uri"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup Batch Transformer \n",
    "We are using a previously-trained model specified at `model_s3_uri`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_s3_uri = 's3://{}/{}/output/model.tar.gz'.format(bucket, training_job_name)\n",
    "\n",
    "batch_model = Model(entry_point='inference.py',\n",
    "                    source_dir='src_batch_tsv',       \n",
    "                    model_data=model_s3_uri,\n",
    "                    role=role,\n",
    "                    framework_version='2.1.0',\n",
    "                    env=batch_env)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "batch_predictor = batch_model.transformer(strategy=strategy, \n",
    "                                          instance_type=instance_type,\n",
    "                                          instance_count=instance_count,\n",
    "                                          accept=accept_type,\n",
    "                                          assemble_with=assemble_with,\n",
    "                                          max_concurrent_transforms=max_concurrent_transforms,\n",
    "                                          max_payload=max_payload, # This is in Megabytes (not number of records)\n",
    "                                          env=batch_env)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Start Batch Predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_predictor.transform(data=raw_input_data_s3_uri,\n",
    "                          split_type=split_type,\n",
    "                          compression_type=compression_type,\n",
    "                          content_type=content_type,\n",
    "#                          join_source='Input', # Need to fix \"Mismatched line count between input and output\" error\n",
    "                          experiment_config=None,\n",
    "                          wait=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.core.display import display, HTML\n",
    "\n",
    "display(HTML('<b>Review <a target=\"blank\" href=\"https://console.aws.amazon.com/sagemaker/home?region={}#/transform-jobs/{}?region={}&tab=Monitor\">Batch Prediction Job</a></b>'.format(region, batch_predictor.latest_transform_job.job_name, region)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.core.display import display, HTML\n",
    "\n",
    "display(HTML('<b>Review <a target=\"blank\" href=\"https://console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/TransformJobs;prefix={};streamFilter=typeLogStreamPrefix\">CloudWatch Logs</a></b>'.format(region, batch_predictor.latest_transform_job.job_name)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.core.display import display, HTML\n",
    "\n",
    "display(HTML('<b>Review <a target=\"blank\" href=\"https://console.aws.amazon.com/s3/buckets/{}/{}/?region={}\">Batch Prediction S3 Output</a></b>'.format(bucket, batch_predictor.latest_transform_job.job_name, region)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('Waiting for batch prediction job: ' + batch_predictor.latest_transform_job.job_name)\n",
    "\n",
    "batch_predictor.wait(logs=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# _Wait Until the ^^ Batch Transform Job ^^ Completes_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Check Output Data\n",
    "\n",
    "After the transform job has completed, download the output data from S3.\n",
    "\n",
    "For each file in the input data, we have a corresponding file with a \".out\" extension.  This .out file contains the predicted labels for each input row. "
   ]
  },
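  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The naming rule described above can be sketched as a small helper that derives the expected output file name for a given input file. `expected_output_name` is a hypothetical function and the bucket and key in the example are placeholders, not values from this notebook.\n",
    "\n",
    "```python\n",
    "# Hypothetical helper (assumption): maps an input object to the name of\n",
    "# its batch prediction output by appending the .out extension.\n",
    "import posixpath\n",
    "\n",
    "\n",
    "def expected_output_name(input_s3_uri):\n",
    "    return posixpath.basename(input_s3_uri) + '.out'\n",
    "\n",
    "\n",
    "# Placeholder URI for illustration only.\n",
    "print(expected_output_name('s3://my-bucket/raw/data.tsv.gz'))\n",
    "```\n",
    "\n",
    "A helper like this can be useful for pairing each downloaded .out file with the input file it came from."
   ]
  },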
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_prediction_output_s3_uri = batch_predictor.output_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!aws s3 cp --recursive $batch_prediction_output_s3_uri/ batch_prediction_output/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!ls batch_prediction_output/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%javascript\n",
    "Jupyter.notebook.save_checkpoint();\n",
    "Jupyter.notebook.session.delete();"
   ]
  }
 ],
 "metadata": {
  "history": [],
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
